#!/bin/env python
# encoding: utf-8
'''
 -- Sample Set Enrichment Analysis (SSEA) --

Assessment of enrichment in a ranked list of quantitative measurements 

@author:     mkiyer
@author:     yniknafs
        
@copyright:  2013 Michigan Center for Translational Pathology. All rights reserved.
        
@license:    GPL2

@contact:    mkiyer@umich.edu
@deffield    updated: Updated
'''
import sys
import os
import logging
import shutil
import subprocess
from collections import namedtuple

# an empty __all__ means "from <this module> import *" exports no names
__all__ = []

# set to a true value to run main() under cProfile when this file is
# executed as a script (see the __main__ guard at the bottom of the file)
PROFILE = 0

# local imports
from ssea.lib.config import Config 
from ssea.lib.base import JobStatus, SampleSet, chunk
from ssea.lib.countdata import BigCountMatrix
from ssea.lib.algo import ssea_map, ssea_reduce, ssea_serial

# file inside each sample set directory that lists the worker row ranges
# (written by write_worker_info, read by read_worker_info)
WORKER_FILE = 'workers.txt'
# names of the PBS scripts saved inside each sample set directory
CLUSTER_MAP_SCRIPT = 'map.pbs'
CLUSTER_REDUCE_SCRIPT = 'reduce.pbs'
# NOTE(review): no use of this name is visible in this module
CLUSTER_MAPREDUCE_SCRIPT = 'run.pbs'
# Cluster walltime computation constants (see ssea_start):
# estimated seconds per test in the map step and per matrix row in the
# reduce step
SECS_PER_ITER_MAP = 3.0e-6
SECS_PER_ITER_REDUCE = 0.01
# conversion factor from seconds to hours
HOURS_PER_SEC = (1.0 / 3600.0)
# hours added to every walltime estimate before truncating to an integer
PADDING_HOURS = 2.0
# Sample set info: one row of the sample set info file
# (see read_sample_set_info / write_sample_set_info)
SSInfo = namedtuple('SSInfo', ('id', 'dirname', 'name', 'desc', 'size'))
# worker process info: one row of WORKER_FILE; startrow/endrow are the
# bounds of the chunk of rows assigned to the worker
WorkerInfo = namedtuple('WorkerInfo', ('id', 'basename', 'startrow', 'endrow'))

def qsub(lines):
    '''
    Submit a job script to the scheduler by piping it to the "qsub"
    command.

    lines: list of strings, one per line of the job script

    returns the text printed by qsub on stdout with surrounding
    whitespace removed (the job id), or None when qsub could not be
    started or exited with a nonzero return code
    '''
    try:
        # universal_newlines=True makes the pipes text-mode on Python 3,
        # so the script can be passed as a regular string on 2 and 3
        p = subprocess.Popen("qsub", stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE,
                             universal_newlines=True)
    except OSError as e:
        # qsub is not installed or not executable; report it the same
        # way as a rejected submission so callers can skip the job
        logging.error("Could not execute qsub: %s" % (e))
        return None
    # let communicate() feed the script: it writes stdin, closes it and
    # drains stdout together, which avoids blocking on a full pipe
    stdoutdata = p.communicate('\n'.join(lines))[0]
    # check return code
    if p.returncode != 0:
        return None
    return stdoutdata.strip()
  
def read_sample_sets(config):
    '''
    Load the sample sets named in the configuration and drop those whose
    size falls outside the configured bounds.

    config: ssea.lib.Config object with SSEA options specified. The
    attributes read here are smx_files, smt_files, json_files, smin and
    smax. A bound that is zero or negative disables that size check.

    returns list of SampleSet objects that passed the size filter
    '''
    logging.info("Reading sample sets")
    # (config attribute holding the file names, SampleSet parser to use),
    # listed in the order the files are read
    sources = (('smx_files', 'parse_smx'),
               ('smt_files', 'parse_smt'),
               ('json_files', 'parse_json'))
    loaded = []
    for files_attr, parser_name in sources:
        for path in getattr(config, files_attr):
            logging.debug("\tFile: %s" % (path))
            parse = getattr(SampleSet, parser_name)
            loaded.extend(parse(path))
    logging.info("\tNumber of sample sets: %d" % (len(loaded)))
    # apply the minimum / maximum size filters
    kept = []
    for sset in loaded:
        size = len(sset)
        if (config.smin > 0) and (size < config.smin):
            logging.warning("\tsample set %s excluded because size %d < %d" %
                            (sset.name, size, config.smin))
        elif (config.smax > 0) and (size > config.smax):
            logging.warning("\tsample set %s excluded because size %d > %d" %
                            (sset.name, size, config.smax))
        else:
            logging.debug("\tsample set %s size %d" % (sset.name, size))
            kept.append(sset)
    logging.info("\tNumber of filtered sample sets: %d" % (len(kept)))
    return kept

def read_sample_set_info(filename):
    '''
    Sample set information is stored as a text file in the base output
    directory of an SSEA run. This information can be helpful to
    find sample sets because the folder names are currently assigned
    integer identifiers.

    filename: path to a tab-delimited file with one header line followed
    by one line per sample set with the columns id, dirname, name, desc
    and size

    returns list of SSInfo tuples in file order (empty when the file
    holds no data lines)
    '''
    ss_infos = []
    with open(filename) as fileh:
        # skip header; the default value keeps an empty file from
        # raising StopIteration
        next(fileh, None)
        for line in fileh:
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. at the end of the file)
                continue
            fields = line.split('\t')
            ss = SSInfo(id=int(fields[0]),
                        dirname=fields[1],
                        name=fields[2],
                        desc=fields[3],
                        size=int(fields[4]))
            ss_infos.append(ss)
    return ss_infos

def write_sample_set_info(ss_infos, filename):
    '''
    Write sample set information as a tab-delimited text file, replacing
    any existing file.

    ss_infos: iterable of SSInfo tuples
    filename: path of the file to write

    The first line holds the SSInfo field names; each following line
    holds one sample set in the same column order.
    '''
    with open(filename, 'w') as fileh:
        # file.write is used instead of the "print >>" statement so the
        # function behaves the same under Python 2 and Python 3
        fileh.write('\t'.join(SSInfo._fields) + '\n')
        for ss in ss_infos:
            fields = [str(ss.id), ss.dirname, ss.name, ss.desc, str(ss.size)]
            fileh.write('\t'.join(fields) + '\n')

def read_worker_info(filename):
    '''
    Read the worker file produced by write_worker_info.

    filename: path to a tab-delimited file whose lines hold the columns
    id, basename, startrow and endrow (no header line)

    returns list of WorkerInfo tuples in file order
    '''
    with open(filename, 'r') as f:
        rows = [line.strip().split('\t') for line in f]
    return [WorkerInfo(id=int(row[0]),
                       basename=row[1],
                       startrow=int(row[2]),
                       endrow=int(row[3]))
            for row in rows]

def write_worker_info(filename, num_jobs, num_processes):
    '''
    Write the worker file that assigns a range of rows to each worker.

    filename: path of the file to write (any existing file is replaced)
    num_jobs, num_processes: passed to chunk(), which yields the
    (startrow, endrow) pair of each worker

    Each line holds the tab-delimited columns id, basename, startrow and
    endrow, where ids count up from zero and the basename is "w<id>".
    '''
    num_workers = 0
    max_chunk_size = 0
    with open(filename, 'w') as f:
        for i, (startrow, endrow) in enumerate(chunk(num_jobs,
                                                     num_processes)):
            fields = [i, 'w%d' % (i), startrow, endrow]
            # file.write instead of "print >>" keeps this portable
            # between Python 2 and Python 3
            f.write('\t'.join(map(str, fields)) + '\n')
            num_workers = i + 1
            max_chunk_size = max(max_chunk_size, endrow - startrow)
    logging.debug("Worker processes: %d" % (num_workers))
    logging.debug("Max transcripts per worker: %d" % (max_chunk_size))

def _ensure_dir(path, label):
    '''Create directory "path" (and parents) unless it already exists.'''
    if not os.path.exists(path):
        logging.debug("Creating %s directory '%s'" % (label, path))
        os.makedirs(path)

def ssea_setup(config):
    '''
    Prepare the output directory for an SSEA run.

    config: ssea.lib.Config object with SSEA options specified

    For every sample set read from the configuration whose name is not
    already listed in the sample set info file, a new "ss_<id>"
    directory is created holding tmp and log directories, the
    configuration and sample set as JSON, and the worker file. The job
    is then marked READY. Sample sets whose name is already listed are
    skipped with a warning. Finally the sample set info file is
    rewritten with the old and new entries.

    returns 0
    '''
    # create output directory
    _ensure_dir(config.output_dir, 'output')
    # read/process weight matrix
    bm = BigCountMatrix.open(config.matrix_dir)
    try:
        if bm.size_factors is None:
            logging.debug("No size factors found in count matrix.. estimating")
            bm.estimate_size_factors('deseq')
        shape = bm.shape
    finally:
        # release the matrix even when size factor estimation fails
        bm.close()
    # read/process sample sets
    sample_sets = read_sample_sets(config)
    # read current file if exists
    sample_set_info_file = os.path.join(config.output_dir,
                                        Config.SAMPLE_SET_INFO_FILE)
    ss_infos = []
    ss_names = set()
    max_ss_id = 0
    if os.path.exists(sample_set_info_file):
        for ss_info in read_sample_set_info(sample_set_info_file):
            ss_infos.append(ss_info)
            ss_names.add(ss_info.name)
            max_ss_id = max(max_ss_id, ss_info.id)
    # read new sample sets
    for sample_set in sample_sets:
        logging.info("Sample Set: %s" % (sample_set.name))
        if sample_set.name in ss_names:
            logging.warning('Sample Set "%s" already exists in output '
                            'directory' % (sample_set.name))
            continue
        # new sample sets continue the numbering of the existing ones
        ss_id = (max_ss_id + 1)
        sample_set_dirname = 'ss_%d' % (ss_id)
        sample_set_path = os.path.join(config.output_dir, sample_set_dirname)
        assert not os.path.exists(sample_set_path)
        max_ss_id = ss_id
        # create sample set, temp and log directories
        _ensure_dir(sample_set_path, 'sample set')
        _ensure_dir(os.path.join(sample_set_path, Config.TMP_DIR), 'tmp')
        _ensure_dir(os.path.join(sample_set_path, Config.LOG_DIR), 'log')
        # write configuration file (file.write instead of "print >>"
        # keeps this portable between Python 2 and Python 3)
        logging.debug("Writing configuration file")
        config_file = os.path.join(sample_set_path, Config.CONFIG_JSON_FILE)
        with open(config_file, 'w') as fp:
            fp.write(config.to_json() + '\n')
        # write sample set file
        logging.debug("Writing sample set file")
        sample_set_file = os.path.join(sample_set_path,
                                       Config.SAMPLE_SET_JSON_FILE)
        with open(sample_set_file, 'w') as fp:
            fp.write(sample_set.to_json() + '\n')
        # map work into chunks
        worker_file = os.path.join(sample_set_path, WORKER_FILE)
        write_worker_info(worker_file,
                          num_jobs=shape[0],
                          num_processes=config.num_processes)
        # add to sample set info list
        ss_infos.append(SSInfo(id=ss_id,
                               dirname=sample_set_dirname,
                               name=sample_set.name,
                               desc=sample_set.desc,
                               size=len(sample_set)))
        # mark job status as ready
        JobStatus.set(sample_set_path, JobStatus.READY)
    # write sample set info list
    write_sample_set_info(ss_infos, sample_set_info_file)
    return 0

def _walltime_hours(num_iters, secs_per_iter):
    '''
    Estimate the walltime of a cluster job.

    returns tuple (hours, padded_hours): the raw estimate, and the
    estimate plus PADDING_HOURS truncated to a whole number of hours
    '''
    hours = (num_iters * secs_per_iter * HOURS_PER_SEC)
    return hours, int(hours + PADDING_HOURS)

def _pbs_preamble(pbs_script_lines, job_name, padded_hours, pmem):
    '''
    Build the opening lines shared by the map and reduce PBS scripts:
    shebang, user supplied lines, job name and resource request.
    '''
    resource_fields = ['walltime=%d:00:00' % (padded_hours),
                       'nodes=1:ppn=1',
                       'pmem=%s' % (pmem)]
    lines = ['#!/bin/bash']
    lines.extend(pbs_script_lines)
    lines.append('#PBS -N %s' % (job_name))
    lines.append('#PBS -l %s' % (','.join(resource_fields)))
    return lines

def _write_and_submit(script_file, lines):
    '''Save a PBS script to script_file and submit it; returns qsub().'''
    with open(script_file, 'w') as f:
        f.write('\n'.join(lines))
    return qsub(lines)

def _run_local(config, sample_set, sample_set_path, tmp_dir, worker_infos):
    '''Run the map and reduce steps in this process tree, then mark DONE.'''
    worker_basenames = [os.path.join(tmp_dir, w.basename)
                        for w in worker_infos]
    worker_chunks = [(w.startrow, w.endrow) for w in worker_infos]
    logging.info("Running SSEA map step with %d parallel processes " %
                 (len(worker_infos)))
    # map step
    ssea_map(config, sample_set, worker_basenames, worker_chunks)
    # reduce step
    logging.info("Running SSEA reduce step")
    hists_file = os.path.join(sample_set_path, Config.OUTPUT_HISTS_FILE)
    json_file = os.path.join(sample_set_path, Config.RESULTS_JSON_FILE)
    ssea_reduce(worker_basenames, json_file, hists_file)
    # cleanup
    if os.path.exists(tmp_dir):
        shutil.rmtree(tmp_dir)
    # mark done
    logging.debug("Sample Set '%s' Done." % (sample_set.name))
    JobStatus.set(sample_set_path, JobStatus.DONE)

def _submit_cluster(config, ss_info, sample_set_path, log_dir, worker_infos,
                    shape, pbs_script_lines, ssea_script):
    '''
    Write and submit the PBS map job array and the dependent reduce job.

    returns True when both submissions succeeded, False otherwise
    '''
    max_chunk_size = max((w.endrow - w.startrow) for w in worker_infos)
    num_workers = len(worker_infos)
    sample_set_abspath = os.path.abspath(sample_set_path)
    # calculate map job walltime from the largest chunk of rows
    num_tests = (max_chunk_size * shape[1] *
                 (1 + config.perms + config.resampling_iterations))
    hours, padded_hours = _walltime_hours(num_tests, SECS_PER_ITER_MAP)
    logging.debug('Map step walltime calculation: %d rows/chunk %d '
                  'samples %d perms %d resamplings' %
                  (max_chunk_size, shape[1], config.perms,
                   config.resampling_iterations))
    logging.debug('\tAt %fs per iteration this will take %f hours' %
                  (SECS_PER_ITER_MAP, hours))
    logging.debug('\tPadded walltime is %dh' % (padded_hours))
    # write map script: one array element per worker
    lines = _pbs_preamble(pbs_script_lines, ss_info.dirname, padded_hours,
                          '1024mb')
    lines.append('#PBS -t 0-%d' % (num_workers-1))
    lines.append('#PBS -V')
    lines.append("#PBS -o %s" % (os.path.join(log_dir, 'map.stdout')))
    lines.append("#PBS -e %s" % (os.path.join(log_dir, 'map.stderr')))
    args = ['python', ssea_script, '-v', '--cluster-map', '-o',
            sample_set_abspath]
    lines.append(' '.join(args))
    lines.append('')
    # try to submit map job
    script_file = os.path.join(sample_set_path, CLUSTER_MAP_SCRIPT)
    job_array_id = _write_and_submit(script_file, lines)
    if job_array_id is None:
        logging.error('Failed to submit job array, skipping job')
        return False
    logging.info('Submitted map job array id=%s' % (job_array_id))
    # calculate walltime for reduce step from the number of rows
    hours, padded_hours = _walltime_hours(shape[0], SECS_PER_ITER_REDUCE)
    logging.debug('Reduce walltime calculation: %d rows' % (shape[0]))
    logging.debug('\tAt %fs per iteration this will take %f hours' %
                  (SECS_PER_ITER_REDUCE, hours))
    logging.debug('\tPadded reduce walltime is %dh' % (padded_hours))
    # write reduce script: starts only after the whole map array succeeds
    lines = _pbs_preamble(pbs_script_lines, ss_info.dirname, padded_hours,
                          '3750mb')
    lines.append('#PBS -V')
    lines.append("#PBS -o %s" % (os.path.join(log_dir, 'reduce.stdout')))
    lines.append("#PBS -e %s" % (os.path.join(log_dir, 'reduce.stderr')))
    lines.append("#PBS -W depend=afterokarray:%s" % (job_array_id))
    args = ['python', ssea_script, '-v', '--cluster-reduce', '-o',
            sample_set_abspath]
    lines.append(' '.join(args))
    lines.append('')
    # submit reduce job
    script_file = os.path.join(sample_set_path, CLUSTER_REDUCE_SCRIPT)
    job_id = _write_and_submit(script_file, lines)
    if job_id is None:
        # NOTE: the map job array submitted above is not cancelled here
        logging.error('Failed to submit reduce job, skipping job')
        return False
    logging.info('Submitted reduce job id=%s' % (job_id))
    return True

def ssea_start(config, ssea_script):
    '''
    multiprocessing map/reduce implementation of ssea

    config: ssea.lib.Config object; output_dir, pbs_script and
    matrix_dir are read from it
    ssea_script: path of the script that cluster jobs should execute

    Every sample set listed in the sample set info file whose job status
    is READY is marked BUSY and processed with the configuration saved
    in its own directory: in-process when that configuration has no
    cluster setting, otherwise by submitting PBS map and reduce jobs.
    A failed submission puts the job back to READY.

    returns 1 when the output directory or the sample set info file is
    missing, 0 otherwise
    '''
    # keep the run-level output directory under its own name: inside the
    # loop each sample set loads its own saved configuration, which must
    # not change where the remaining sample sets are looked up
    output_dir = config.output_dir
    # check output directory
    if not os.path.exists(output_dir):
        logging.error("Output directory '%s' not found" %
                      (output_dir))
        return 1
    sample_set_info_file = os.path.join(output_dir,
                                        Config.SAMPLE_SET_INFO_FILE)
    if not os.path.exists(sample_set_info_file):
        logging.error("Sample set info file '%s' not found" %
                      (sample_set_info_file))
        return 1
    # read pbs configuration commands (cluster mode only)
    pbs_script_lines = []
    if config.pbs_script is not None:
        with open(config.pbs_script) as f:
            pbs_script_lines.extend(line.strip() for line in f)
    # read/process weight matrix
    bm = BigCountMatrix.open(config.matrix_dir)
    try:
        shape = bm.shape
    finally:
        bm.close()
    # process each sample set
    for ss_info in read_sample_set_info(sample_set_info_file):
        logging.info('Sample Set: %s' % (ss_info.name))
        # get output directory
        sample_set_path = os.path.join(output_dir, ss_info.dirname)
        if not os.path.exists(sample_set_path):
            logging.warning('Could not find run folder for sample Set "%s"' % (ss_info.name))
            continue
        # check that job is not already running or done
        if JobStatus.get(sample_set_path) != JobStatus.READY:
            logging.info("Skipping Sample Set: %s DONE/BUSY" % (ss_info.name))
            continue
        # set job as busy (job will now be run)
        JobStatus.set(sample_set_path, JobStatus.BUSY)
        tmp_dir = os.path.join(sample_set_path, Config.TMP_DIR)
        log_dir = os.path.join(sample_set_path, Config.LOG_DIR)
        # read the configuration saved with this sample set
        config_json_file = os.path.join(sample_set_path,
                                        Config.CONFIG_JSON_FILE)
        ss_config = Config.parse_json(config_json_file)
        # read sample sets
        sample_set_json_file = os.path.join(sample_set_path,
                                            Config.SAMPLE_SET_JSON_FILE)
        sample_set = SampleSet.parse_json(sample_set_json_file)[0]
        # read worker information
        worker_file = os.path.join(sample_set_path, WORKER_FILE)
        worker_infos = read_worker_info(worker_file)
        if ss_config.cluster is None:
            _run_local(ss_config, sample_set, sample_set_path, tmp_dir,
                       worker_infos)
        elif not _submit_cluster(ss_config, ss_info, sample_set_path,
                                 log_dir, worker_infos, shape,
                                 pbs_script_lines, ssea_script):
            # submission failed: make the job available for a later run
            JobStatus.set(sample_set_path, JobStatus.READY)
    return 0

def ssea_cluster_map(output_dir):
    '''
    Run the map step for a single element of a PBS job array.

    output_dir: sample set directory holding the configuration, sample
    set and worker files

    The worker to run is the entry of the worker file whose id equals
    the integer value of the PBS_ARRAYID environment variable.

    returns 1 when PBS_ARRAYID is not set or matches no worker,
    otherwise the return value of ssea_serial
    '''
    # load the configuration and sample set saved in the output directory
    config = Config.parse_json(
        os.path.join(output_dir, Config.CONFIG_JSON_FILE))
    sample_set = SampleSet.parse_json(
        os.path.join(output_dir, Config.SAMPLE_SET_JSON_FILE))[0]
    # the PBS_ARRAYID environment variable should be set
    array_id = os.environ.get('PBS_ARRAYID', None)
    if array_id is None:
        logging.error('Could not find "PBS_ARRAYID" environment variable')
        return 1
    worker_index = int(array_id)
    # pick the first worker entry with a matching id
    candidates = read_worker_info(os.path.join(output_dir, WORKER_FILE))
    worker = next((w for w in candidates if w.id == worker_index), None)
    if worker is None:
        logging.error("Worker index '%d' not found in worker file" % (worker_index))
        return 1
    # run SSEA algorithm on the rows assigned to this worker
    worker_basename = os.path.join(output_dir, Config.TMP_DIR, worker.basename)
    return ssea_serial(config, sample_set, worker_basename,
                       worker.startrow, worker.endrow)

def ssea_cluster_reduce(output_dir):
    '''
    Run the reduce step for a sample set processed on the cluster.

    output_dir: sample set directory holding the worker file and the
    tmp directory with the per-worker output

    Calls ssea_reduce with the per-worker basenames and the results and
    histogram file paths, removes the tmp directory and marks the job
    as DONE.

    returns 0
    '''
    # one basename inside the tmp directory per worker listed in the file
    workers = read_worker_info(os.path.join(output_dir, WORKER_FILE))
    tmp_dir = os.path.join(output_dir, Config.TMP_DIR)
    worker_basenames = [os.path.join(tmp_dir, w.basename) for w in workers]
    # reduce step
    logging.info("Computing FDR q values")
    hists_file = os.path.join(output_dir, Config.OUTPUT_HISTS_FILE)
    json_file = os.path.join(output_dir, Config.RESULTS_JSON_FILE)
    ssea_reduce(worker_basenames, json_file, hists_file)
    # cleanup
    if os.path.exists(tmp_dir):
        shutil.rmtree(tmp_dir)
    # mark done
    JobStatus.set(output_dir, JobStatus.DONE)
    logging.debug("Done.")
    return 0

def main():
    '''
    Command line entry point: parse the configuration and dispatch on
    config.cluster.

    None or 'setup' prepares the output directory and starts the run,
    'map' runs one cluster map worker and 'reduce' runs the cluster
    reduce step.

    returns the return code of the step that was run, or 1 when
    config.cluster has an unrecognized value
    '''
    # create instance of run configuration
    config = Config.parse_args()
    # show configuration
    config.log()
    # decide how to run ssea
    if (config.cluster is None) or (config.cluster == 'setup'):
        # setup run; do not start anything if setup reports a failure
        retcode = ssea_setup(config)
        if retcode:
            return retcode
        # get absolute path to this script
        ssea_script = os.path.abspath(sys.argv[0])
        return ssea_start(config, ssea_script)
    if config.cluster == 'map':
        return ssea_cluster_map(config.output_dir)
    if config.cluster == 'reduce':
        return ssea_cluster_reduce(config.output_dir)
    # previously this fell through to an unbound return value
    logging.error("Unrecognized cluster mode '%s'" % (config.cluster))
    return 1

if __name__ == "__main__":
    if PROFILE:
        # run main() under cProfile and save a report sorted by
        # cumulative time; the return code of main() is not used here
        import cProfile
        import pstats
        profile_filename = '_profile.bin'
        cProfile.run('main()', profile_filename)
        # pstats prints text, so the report is opened in text mode; the
        # with-block closes it even if building the report fails
        with open("profile_stats.txt", "w") as statsfile:
            p = pstats.Stats(profile_filename, stream=statsfile)
            p.strip_dirs().sort_stats('cumulative').print_stats()
        sys.exit(0)
    sys.exit(main())